from __future__ import print_function

import time

# from utils import list_images
import os
import h5py
import numpy as np
import matplotlib.pyplot as plt
from train import train,many_train
# from generate import generate
import scipy.ndimage
from utils import mkdir

# Training hyper-parameters.
BATCH_SIZE = 16  # NOTE: this value was changed here (previously 24)
EPOCHES = 3
LOGGING_STEP = 240
MODEL_SAVE_PATH = 'model/'

# Dataset location: set the training HDF5 file here.
# Raw string so the backslash in the Windows path is not parsed as an
# (invalid) escape sequence ("\T"), which newer Pythons warn about.
TRAIN_DATA_PATH = r'D:\TrainData_size64/train.h5'

# Read the whole 'data' dataset into memory ([:] materializes it), then close
# the HDF5 file immediately instead of keeping the handle open for the run.
with h5py.File(TRAIN_DATA_PATH, 'r') as f:
    sources = f['data'][:]
# Axis reordering that was used at some point (left disabled):
# sources = np.transpose(sources, (0, 3, 2, 1))

print('\nBegin to train the network ...\n')

# Alternative: a single training run without the loop (disabled).
# train(sources, MODEL_SAVE_PATH, EPOCHES, BATCH_SIZE, logging_period=LOGGING_STEP)

# Per the original note, many_train wraps training in a loop so the network
# can be trained several times in one invocation.
many_train(sources, MODEL_SAVE_PATH, EPOCHES, BATCH_SIZE, logging_period=LOGGING_STEP)


































